{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tTLuI6Xz4yvr"
   },
   "source": [
    "##### Copyright 2020 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "i5R7Gb120Ekx"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XPflk8Ey7e9D"
   },
   "source": [
    "# Writing a training loop from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0pA8D5OrHGIx"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/keras-team/keras-io/blob/master/tf/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/keras-team/keras-io/blob/master/guides/writing_a_training_loop_from_scratch.py\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/keras-io/tf/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "l7oVQp2POvp8"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "jkWTZxIyysDm"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TcNr4V6Z2eDw"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "Keras provides default training and evaluation loops, `fit()` and `evaluate()`.\n",
    "Their usage is coverered in the guide\n",
    "[Training & evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate/).\n",
    "\n",
    "If you want to customize the learning algorithm of your model while still leveraging\n",
    "the convenience of `fit()`\n",
    "(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and\n",
    "implement your own `train_step()` method, which\n",
    "is called repeatedly during `fit()`. This is covered in the guide\n",
    "[Customizing what happens in `fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit/).\n",
    "\n",
    "Now, if you want very low-level control over training & evaluation, you should write\n",
    "your own training & evaluation loops from scratch. This is what this guide is about."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MvWJHarCmdOp"
   },
   "source": [
    "## Using the `GradientTape`: a first end-to-end example\n",
    "\n",
    "Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of\n",
    "the trainable weights of the layer with respect to a loss value. Using an optimizer\n",
    "instance, you can use these gradients to update these variables (which you can\n",
    "retrieve using `model.trainable_weights`).\n",
    "\n",
    "Let's consider a simple MNIST model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "Xd8GpFudzEiU"
   },
   "outputs": [],
   "source": [
    "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
    "x1 = layers.Dense(64, activation=\"relu\")(inputs)\n",
    "x2 = layers.Dense(64, activation=\"relu\")(x1)\n",
    "outputs = layers.Dense(10, name=\"predictions\")(x2)\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9uXwjRTjgQKE"
   },
   "source": [
    "Let's train it using mini-batch gradient with a custom training loop.\n",
    "\n",
    "First, we're going to need an optimizer, a loss function, and a dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "xnxsNbPmMvqr"
   },
   "outputs": [],
   "source": [
    "# Instantiate an optimizer.\n",
    "optimizer = keras.optimizers.SGD(learning_rate=1e-3)\n",
    "# Instantiate a loss function.\n",
    "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "\n",
    "# Prepare the training dataset.\n",
    "batch_size = 64\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "x_train = np.reshape(x_train, (-1, 784))\n",
    "x_test = np.reshape(x_train, (-1, 784))\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VmZj17qo05qZ"
   },
   "source": [
    "Here's our training loop:\n",
    "\n",
    "- We open a `for` loop that iterates over epochs\n",
    "- For each epoch, we open a `for` loop that iterates over the dataset, in batches\n",
    "- For each batch, we open a `GradientTape()` scope\n",
    "- Inside this scope, we call the model (forward pass) and compute the loss\n",
    "- Outside the scope, we retrieve the gradients of the weights\n",
    "of the model with regard to the loss\n",
    "- Finally, we use the optimizer to update the weights of the model based on the\n",
    "gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "KP7wZij8LJyh"
   },
   "outputs": [],
   "source": [
    "epochs = 2\n",
    "for epoch in range(epochs):\n",
    "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
    "\n",
    "    # Iterate over the batches of the dataset.\n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "\n",
    "        # Open a GradientTape to record the operations run\n",
    "        # during the forward pass, which enables autodifferentiation.\n",
    "        with tf.GradientTape() as tape:\n",
    "\n",
    "            # Run the forward pass of the layer.\n",
    "            # The operations that the layer applies\n",
    "            # to its inputs are going to be recorded\n",
    "            # on the GradientTape.\n",
    "            logits = model(x_batch_train, training=True)  # Logits for this minibatch\n",
    "\n",
    "            # Compute the loss value for this minibatch.\n",
    "            loss_value = loss_fn(y_batch_train, logits)\n",
    "\n",
    "        # Use the gradient tape to automatically retrieve\n",
    "        # the gradients of the trainable variables with respect to the loss.\n",
    "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
    "\n",
    "        # Run one step of gradient descent by updating\n",
    "        # the value of the variables to minimize the loss.\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "\n",
    "        # Log every 200 batches.\n",
    "        if step % 200 == 0:\n",
    "            print(\n",
    "                \"Training loss (for one batch) at step %d: %.4f\"\n",
    "                % (step, float(loss_value))\n",
    "            )\n",
    "            print(\"Seen so far: %s samples\" % ((step + 1) * 64))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vH60dyqVSDTX"
   },
   "source": [
    "## Low-level handling of metrics\n",
    "\n",
    "Let's add metrics monitoring to this basic loop.\n",
    "\n",
    "You can readily reuse the built-in metrics (or custom ones you wrote) in such training\n",
    "loops written from scratch. Here's the flow:\n",
    "\n",
    "- Instantiate the metric at the start of the loop\n",
    "- Call `metric.update_state()` after each batch\n",
    "- Call `metric.result()` when you need to display the current value of the metric\n",
    "- Call `metric.reset_states()` when you need to clear the state of the metric\n",
    "(typically at the end of an epoch)\n",
    "\n",
    "Let's use this knowledge to compute `SparseCategoricalAccuracy` on validation data at\n",
    "the end of each epoch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "JUriDlmmzLey"
   },
   "outputs": [],
   "source": [
    "# Get model\n",
    "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
    "x = layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
    "x = layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
    "outputs = layers.Dense(10, name=\"predictions\")(x)\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)\n",
    "\n",
    "# Instantiate an optimizer to train the model.\n",
    "optimizer = keras.optimizers.SGD(learning_rate=1e-3)\n",
    "# Instantiate a loss function.\n",
    "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "\n",
    "# Prepare the metrics.\n",
    "train_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
    "val_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
    "\n",
    "# Prepare the training dataset.\n",
    "batch_size = 64\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
    "\n",
    "# Prepare the validation dataset.\n",
    "# Reserve 10,000 samples for validation.\n",
    "x_val = x_train[-10000:]\n",
    "y_val = y_train[-10000:]\n",
    "x_train = x_train[:-10000]\n",
    "y_train = y_train[:-10000]\n",
    "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
    "val_dataset = val_dataset.batch(64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IpAbe0hltQSv"
   },
   "source": [
    "Here's our training & evaluation loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "XYep4rMxu9D1"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "epochs = 2\n",
    "for epoch in range(epochs):\n",
    "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Iterate over the batches of the dataset.\n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "        with tf.GradientTape() as tape:\n",
    "            logits = model(x_batch_train, training=True)\n",
    "            loss_value = loss_fn(y_batch_train, logits)\n",
    "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "\n",
    "        # Update training metric.\n",
    "        train_acc_metric.update_state(y_batch_train, logits)\n",
    "\n",
    "        # Log every 200 batches.\n",
    "        if step % 200 == 0:\n",
    "            print(\n",
    "                \"Training loss (for one batch) at step %d: %.4f\"\n",
    "                % (step, float(loss_value))\n",
    "            )\n",
    "            print(\"Seen so far: %d samples\" % ((step + 1) * 64))\n",
    "\n",
    "    # Display metrics at the end of each epoch.\n",
    "    train_acc = train_acc_metric.result()\n",
    "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
    "\n",
    "    # Reset training metrics at the end of each epoch\n",
    "    train_acc_metric.reset_states()\n",
    "\n",
    "    # Run a validation loop at the end of each epoch.\n",
    "    for x_batch_val, y_batch_val in val_dataset:\n",
    "        val_logits = model(x_batch_val, training=False)\n",
    "        # Update val metrics\n",
    "        val_acc_metric.update_state(y_batch_val, val_logits)\n",
    "    val_acc = val_acc_metric.result()\n",
    "    val_acc_metric.reset_states()\n",
    "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
    "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uIzUVMFmKOL6"
   },
   "source": [
    "## Speeding-up your training step with `tf.function`\n",
    "\n",
    "The default runtime in TensorFlow 2.0 is\n",
    "[eager execution](https://www.tensorflow.org/guide/eager). As such, our training loop\n",
    "above executes eagerly.\n",
    "\n",
    "This is great for debugging, but graph compilation has a definite performance\n",
    "advantage. Decribing your computation as a static graph enables the framework\n",
    "to apply global performance optimizations. This is impossible when\n",
    "the framework is constrained to greedly execute one operation after another,\n",
    "with no knowledge of what comes next.\n",
    "\n",
    "You can compile into a static graph any function that take tensors as input.\n",
    "Just add a `@tf.function` decorator on it, like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "rdU2jDOB4itB"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(x, y):\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(x, training=True)\n",
    "        loss_value = loss_fn(y, logits)\n",
    "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "    train_acc_metric.update_state(y, logits)\n",
    "    return loss_value\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SpPlBEHjXg9c"
   },
   "source": [
    "Let's do the same with the evaluation step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "2yGMv9LPrvYl"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def test_step(x, y):\n",
    "    val_logits = model(x, training=False)\n",
    "    val_acc_metric.update_state(y, val_logits)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y7Elq9Kky2HQ"
   },
   "source": [
    "Now, let's re-run our training loop with this compiled training step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "jb12E4zoBDOy"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "epochs = 2\n",
    "for epoch in range(epochs):\n",
    "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Iterate over the batches of the dataset.\n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "        loss_value = train_step(x_batch_train, y_batch_train)\n",
    "\n",
    "        # Log every 200 batches.\n",
    "        if step % 200 == 0:\n",
    "            print(\n",
    "                \"Training loss (for one batch) at step %d: %.4f\"\n",
    "                % (step, float(loss_value))\n",
    "            )\n",
    "            print(\"Seen so far: %d samples\" % ((step + 1) * 64))\n",
    "\n",
    "    # Display metrics at the end of each epoch.\n",
    "    train_acc = train_acc_metric.result()\n",
    "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
    "\n",
    "    # Reset training metrics at the end of each epoch\n",
    "    train_acc_metric.reset_states()\n",
    "\n",
    "    # Run a validation loop at the end of each epoch.\n",
    "    for x_batch_val, y_batch_val in val_dataset:\n",
    "        test_step(x_batch_val, y_batch_val)\n",
    "\n",
    "    val_acc = val_acc_metric.result()\n",
    "    val_acc_metric.reset_states()\n",
    "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
    "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kG0YiXFj519D"
   },
   "source": [
    "Much faster, isn't it?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LtECMpVqBAl8"
   },
   "source": [
    "## Low-level handling of losses tracked by the model\n",
    "\n",
    "Layers & models recursively track any losses created during the forward pass\n",
    "by layers that call `self.add_loss(value)`. The resulting list of scalar loss\n",
    "values are available via the property `model.losses`\n",
    "at the end of the forward pass.\n",
    "\n",
    "If you want to be using these loss components, you should sum them\n",
    "and add them to the main loss in your training step.\n",
    "\n",
    "Consider this layer, that creates an activity regularization loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "EfiB6akb2kBT"
   },
   "outputs": [],
   "source": [
    "class ActivityRegularizationLayer(layers.Layer):\n",
    "    def call(self, inputs):\n",
    "        self.add_loss(1e-2 * tf.reduce_sum(inputs))\n",
    "        return inputs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FuaN95bX7txO"
   },
   "source": [
    "Let's build a really simple model that uses it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "u8JIjdc6sWRo"
   },
   "outputs": [],
   "source": [
    "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
    "x = layers.Dense(64, activation=\"relu\")(inputs)\n",
    "# Insert activity regularization as a layer\n",
    "x = ActivityRegularizationLayer()(x)\n",
    "x = layers.Dense(64, activation=\"relu\")(x)\n",
    "outputs = layers.Dense(10, name=\"predictions\")(x)\n",
    "\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xWEGiFOAg2Pw"
   },
   "source": [
    "Here's what our training step should look like now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "7dDTCiM3ikkt"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(x, y):\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(x, training=True)\n",
    "        loss_value = loss_fn(y, logits)\n",
    "        # Add any extra losses created during the forward pass.\n",
    "        loss_value += sum(model.losses)\n",
    "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "    train_acc_metric.update_state(y, logits)\n",
    "    return loss_value\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ht7Utep1WEhg"
   },
   "source": [
    "## Summary\n",
    "\n",
    "Now you know everything there is to know about using built-in training loops and\n",
    "writing your own from scratch.\n",
    "\n",
    "To conclude, here's a simple end-to-end example that ties together everything\n",
    "you've learned in this guide: a DCGAN trained on MNIST digits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "g1HCZDicyRMP"
   },
   "source": [
    "## End-to-end example: a GAN training loop from scratch\n",
    "\n",
    "You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new\n",
    "images that look almost real, by learning the latent distribution of a training\n",
    "dataset of images (the \"latent space\" of the images).\n",
    "\n",
    "A GAN is made of two parts: a \"generator\" model that maps points in the latent\n",
    "space to points in image space, an a \"discriminator\" model, a classifier\n",
    "that can tell the difference between real imagees (from the training dataset)\n",
    "and fake images (the output of the generator network).\n",
    "\n",
    "A GAN training loop looks like this:\n",
    "\n",
    "1) Train the discriminator.\n",
    "- Sample a batch of random points in the latent space.\n",
    "- Turn the points into fake images via the \"generator\" model.\n",
    "- Get a batch of real images and combine them with the generated images.\n",
    "- Train the \"discriminator\" model to classify generated vs. real images.\n",
    "\n",
    "2) Train the generator.\n",
    "- Sample random points in the latent space.\n",
    "- Turn the points into fake images via the \"generator\" network.\n",
    "- Get a batch of real images and combine them with the generated images.\n",
    "- Train the \"generator\" model to \"fool\" the discriminator and classify the fake images\n",
    "as real.\n",
    "\n",
    "For a much more detailed overview of how GANs works, see\n",
    "[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).\n",
    "\n",
    "Let's implement this training loop. First, create the discriminator meant to classify\n",
    "fake vs real digits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "oo0K12W14lTB"
   },
   "outputs": [],
   "source": [
    "discriminator = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(28, 28, 1)),\n",
    "        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.GlobalMaxPooling2D(),\n",
    "        layers.Dense(1),\n",
    "    ],\n",
    "    name=\"discriminator\",\n",
    ")\n",
    "discriminator.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CtDX25Xojcan"
   },
   "source": [
    "Then let's create a generator network,\n",
    "that turns latent vectors into outputs of shape `(28, 28, 1)` (representing\n",
    "MNIST digits):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "y01N2V7Ykamk"
   },
   "outputs": [],
   "source": [
    "latent_dim = 128\n",
    "\n",
    "generator = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(latent_dim,)),\n",
    "        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n",
    "        layers.Dense(7 * 7 * 128),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Reshape((7, 7, 128)),\n",
    "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n",
    "    ],\n",
    "    name=\"generator\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1D0KbjVFlyye"
   },
   "source": [
    "Here's the key bit: the training loop. As you can see it is quite straightforward. The\n",
    "training step function only takes 17 lines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "MGbCyv5MNPme"
   },
   "outputs": [],
   "source": [
    "# Instantiate one optimizer for the discriminator and another for the generator.\n",
    "d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)\n",
    "g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)\n",
    "\n",
    "# Instantiate a loss function.\n",
    "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def train_step(real_images):\n",
    "    # Sample random points in the latent space\n",
    "    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n",
    "    # Decode them to fake images\n",
    "    generated_images = generator(random_latent_vectors)\n",
    "    # Combine them with real images\n",
    "    combined_images = tf.concat([generated_images, real_images], axis=0)\n",
    "\n",
    "    # Assemble labels discriminating real from fake images\n",
    "    labels = tf.concat(\n",
    "        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0\n",
    "    )\n",
    "    # Add random noise to the labels - important trick!\n",
    "    labels += 0.05 * tf.random.uniform(labels.shape)\n",
    "\n",
    "    # Train the discriminator\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = discriminator(combined_images)\n",
    "        d_loss = loss_fn(labels, predictions)\n",
    "    grads = tape.gradient(d_loss, discriminator.trainable_weights)\n",
    "    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))\n",
    "\n",
    "    # Sample random points in the latent space\n",
    "    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n",
    "    # Assemble labels that say \"all real images\"\n",
    "    misleading_labels = tf.zeros((batch_size, 1))\n",
    "\n",
    "    # Train the generator (note that we should *not* update the weights\n",
    "    # of the discriminator)!\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = discriminator(generator(random_latent_vectors))\n",
    "        g_loss = loss_fn(misleading_labels, predictions)\n",
    "    grads = tape.gradient(g_loss, generator.trainable_weights)\n",
    "    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))\n",
    "    return d_loss, g_loss, generated_images\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UwFhElNElUpn"
   },
   "source": [
    "Let's train our GAN, by repeatedly calling `train_step` on batches of images.\n",
    "\n",
    "Since our discriminator and generator are convnets, you're going to want to\n",
    "run this code on a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code",
    "id": "TfCAGhKdDEVa"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Prepare the dataset. We use both the training & test MNIST digits.\n",
    "batch_size = 64\n",
    "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
    "all_digits = np.concatenate([x_train, x_test])\n",
    "all_digits = all_digits.astype(\"float32\") / 255.0\n",
    "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n",
    "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n",
    "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
    "\n",
    "epochs = 1  # In practice you need at least 20 epochs to generate nice digits.\n",
    "save_dir = \"./\"\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(\"\\nStart epoch\", epoch)\n",
    "\n",
    "    for step, real_images in enumerate(dataset):\n",
    "        # Train the discriminator & generator on one batch of real images.\n",
    "        d_loss, g_loss, generated_images = train_step(real_images)\n",
    "\n",
    "        # Logging.\n",
    "        if step % 200 == 0:\n",
    "            # Print metrics\n",
    "            print(\"discriminator loss at step %d: %.2f\" % (step, d_loss))\n",
    "            print(\"adversarial loss at step %d: %.2f\" % (step, g_loss))\n",
    "\n",
    "            # Save one generated image\n",
    "            img = tf.keras.preprocessing.image.array_to_img(\n",
    "                generated_images[0] * 255.0, scale=False\n",
    "            )\n",
    "            img.save(os.path.join(save_dir, \"generated_img\" + str(step) + \".png\"))\n",
    "\n",
    "        # To limit execution time we stop after 10 steps.\n",
    "        # Remove the lines below to actually train the model!\n",
    "        if step > 10:\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RF070l5Ik9xo"
   },
   "source": [
    "That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the\n",
    "Colab GPU."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "writing_a_training_loop_from_scratch.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}